{ "cells": [ { "cell_type": "markdown", "source": [ "# Stacked Statistic\n", "\n", "In this notebook, we'll show the use of the StackedLinearFractionalStatistic, which is a helpful class to combine multiple linear-fractional statistics. For example, the well-known fairness definition of *equalised odds* enforces equality in both the true positive rate and the false positive rate across sensitive groups. Hence, we can keep track of these statistics in a single vector-valued statistic:" ], "metadata": { "collapsed": false }, "id": "ff3654068006473d" }, { "cell_type": "code", "execution_count": 1, "outputs": [], "source": [ "from fairret.statistic import TruePositiveRate, FalsePositiveRate, StackedLinearFractionalStatistic\n", "\n", "equalised_odds_stats = StackedLinearFractionalStatistic(TruePositiveRate(), FalsePositiveRate())" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:10:14.497622400Z", "start_time": "2024-07-23T15:10:11.499666100Z" } }, "id": "ff837c345568629b" }, { "cell_type": "markdown", "source": [ "Let's quickly try it out..." ], "metadata": { "collapsed": false }, "id": "d160ad2b65016c65" }, { "cell_type": "code", "execution_count": 2, "outputs": [], "source": [ "import torch\n", "torch.manual_seed(0)\n", "\n", "feat = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])\n", "sens = torch.tensor([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])\n", "label = torch.tensor([[0.], [1.], [0.], [1.]])\n", "\n", "from fairret.loss import NormLoss\n", "\n", "norm_loss = NormLoss(equalised_odds_stats)\n", "\n", "h_layer_dim = 16\n", "lr = 1e-3\n", "batch_size = 1024\n", "\n", "def build_model():\n", " _model = torch.nn.Sequential(\n", " torch.nn.Linear(feat.shape[1], h_layer_dim),\n", " torch.nn.ReLU(),\n", " torch.nn.Linear(h_layer_dim, 1)\n", " )\n", " _optimizer = torch.optim.Adam(_model.parameters(), lr=lr)\n", " return _model, _optimizer\n", "\n", "from torch.utils.data import TensorDataset, DataLoader\n", "dataset = TensorDataset(feat, sens, label)\n", "dataloader = DataLoader(dataset, batch_size=batch_size)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:10:14.859486100Z", "start_time": "2024-07-23T15:10:14.503150Z" } }, "id": "e7e46c4842e9d9e5" }, { "cell_type": "markdown", "source": [ "Without fairret..." 
], "metadata": { "collapsed": false }, "id": "8ba44232a13a288a" }, { "cell_type": "code", "execution_count": 3, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, loss: 0.7091795206069946\n", "Epoch: 1, loss: 0.7061765193939209\n", "Epoch: 2, loss: 0.7033581733703613\n", "Epoch: 3, loss: 0.7007156610488892\n", "Epoch: 4, loss: 0.6982340812683105\n", "Epoch: 5, loss: 0.6959078907966614\n", "Epoch: 6, loss: 0.6937355995178223\n", "Epoch: 7, loss: 0.6917158365249634\n", "Epoch: 8, loss: 0.6898466944694519\n", "Epoch: 9, loss: 0.6881252527236938\n", "Epoch: 10, loss: 0.6865478754043579\n", "Epoch: 11, loss: 0.6851094961166382\n", "Epoch: 12, loss: 0.6838041543960571\n", "Epoch: 13, loss: 0.6826250553131104\n", "Epoch: 14, loss: 0.6815641522407532\n", "Epoch: 15, loss: 0.6806124448776245\n", "Epoch: 16, loss: 0.6797604560852051\n", "Epoch: 17, loss: 0.6789975762367249\n", "Epoch: 18, loss: 0.6783132553100586\n", "Epoch: 19, loss: 0.6776963472366333\n", "Epoch: 20, loss: 0.6771360039710999\n", "Epoch: 21, loss: 0.6766215562820435\n", "Epoch: 22, loss: 0.6761429309844971\n", "Epoch: 23, loss: 0.6756909489631653\n", "Epoch: 24, loss: 0.6752569675445557\n", "Epoch: 25, loss: 0.6748337745666504\n", "Epoch: 26, loss: 0.674415111541748\n", "Epoch: 27, loss: 0.673996090888977\n", "Epoch: 28, loss: 0.6735726594924927\n", "Epoch: 29, loss: 0.6731564998626709\n", "Epoch: 30, loss: 0.6727579236030579\n", "Epoch: 31, loss: 0.672345757484436\n", "Epoch: 32, loss: 0.6719199419021606\n", "Epoch: 33, loss: 0.6714813709259033\n", "Epoch: 34, loss: 0.6710319519042969\n", "Epoch: 35, loss: 0.6705741882324219\n", "Epoch: 36, loss: 0.6701083779335022\n", "Epoch: 37, loss: 0.669636607170105\n", "Epoch: 38, loss: 0.6691610217094421\n", "Epoch: 39, loss: 0.6686834096908569\n", "Epoch: 40, loss: 0.6682056188583374\n", "Epoch: 41, loss: 0.6677289009094238\n", "Epoch: 42, loss: 0.667254626750946\n", "Epoch: 43, loss: 0.6667835712432861\n", "Epoch: 44, loss: 0.6663164496421814\n", "Epoch: 45, loss: 0.6658533811569214\n", "Epoch: 46, loss: 0.6653945446014404\n", "Epoch: 47, loss: 0.6649397015571594\n", "Epoch: 48, loss: 0.6644884347915649\n", "Epoch: 49, loss: 0.6640403270721436\n", "Epoch: 50, loss: 0.6635947227478027\n", "Epoch: 51, loss: 0.6631510257720947\n", "Epoch: 52, loss: 0.6628453135490417\n", "Epoch: 53, loss: 0.6625917553901672\n", "Epoch: 54, loss: 0.6623181104660034\n", "Epoch: 55, loss: 0.6620256900787354\n", "Epoch: 56, loss: 0.6617173552513123\n", "Epoch: 57, loss: 0.6614043116569519\n", "Epoch: 58, loss: 0.6610796451568604\n", "Epoch: 59, loss: 0.6607442498207092\n", "Epoch: 60, loss: 0.6603990793228149\n", "Epoch: 61, loss: 0.6600450277328491\n", "Epoch: 62, loss: 0.6596829295158386\n", "Epoch: 63, loss: 0.6593135595321655\n", "Epoch: 64, loss: 0.6589376330375671\n", "Epoch: 65, loss: 0.6585558652877808\n", "Epoch: 66, loss: 0.6581688523292542\n", "Epoch: 67, loss: 0.6577771306037903\n", "Epoch: 68, loss: 0.6574320793151855\n", "Epoch: 69, loss: 0.6571431756019592\n", "Epoch: 70, loss: 0.6568371653556824\n", "Epoch: 71, loss: 0.6565203666687012\n", "Epoch: 72, loss: 0.6561905145645142\n", "Epoch: 73, loss: 0.6558488607406616\n", "Epoch: 74, loss: 0.65549635887146\n", "Epoch: 75, loss: 0.6551340818405151\n", "Epoch: 76, loss: 0.6547629237174988\n", "Epoch: 77, loss: 0.6544535160064697\n", "Epoch: 78, loss: 0.6541627645492554\n", "Epoch: 79, loss: 0.6538523435592651\n", "Epoch: 80, loss: 0.6535260677337646\n", "Epoch: 81, loss: 0.6531944274902344\n", "Epoch: 82, loss: 
0.6528521776199341\n", "Epoch: 83, loss: 0.6525000333786011\n", "Epoch: 84, loss: 0.652138888835907\n", "Epoch: 85, loss: 0.6518597602844238\n", "Epoch: 86, loss: 0.6515651345252991\n", "Epoch: 87, loss: 0.6512539982795715\n", "Epoch: 88, loss: 0.6509299874305725\n", "Epoch: 89, loss: 0.650594174861908\n", "Epoch: 90, loss: 0.6502466797828674\n", "Epoch: 91, loss: 0.6498894691467285\n", "Epoch: 92, loss: 0.6495950818061829\n", "Epoch: 93, loss: 0.6493034362792969\n", "Epoch: 94, loss: 0.6489962339401245\n", "Epoch: 95, loss: 0.6486777067184448\n", "Epoch: 96, loss: 0.6483432650566101\n", "Epoch: 97, loss: 0.6479994058609009\n", "Epoch: 98, loss: 0.6476455330848694\n", "Epoch: 99, loss: 0.6473514437675476\n", "The TPR and FPR for group 0 are tensor([0.5537, 0.4752], grad_fn=<SelectBackward0>)\n", "The TPR and FPR for group 1 are tensor([0.6604, 0.6083], grad_fn=<SelectBackward0>)\n", "The absolute differences are tensor([0.1067, 0.1331], grad_fn=<AbsBackward0>)\n" ] } ], "source": [ "import numpy as np\n", "\n", "nb_epochs = 100\n", "model, optimizer = build_model()\n", "for epoch in range(nb_epochs):\n", "    losses = []\n", "    for batch_feat, batch_sens, batch_label in dataloader:\n", "        optimizer.zero_grad()\n", "\n", "        logit = model(batch_feat)\n", "        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)\n", "        loss.backward()\n", "\n", "        optimizer.step()\n", "        losses.append(loss.item())\n", "    print(f\"Epoch: {epoch}, loss: {np.mean(losses)}\")\n", "\n", "# Evaluate the stacked statistic per sensitive group on the full dataset.\n", "pred = torch.sigmoid(model(feat))\n", "eo_per_group = equalised_odds_stats(pred, sens, label)\n", "absolute_diff = torch.abs(eo_per_group[:, 0] - eo_per_group[:, 1])\n", "\n", "print(f\"The TPR and FPR for group 0 are {eo_per_group[:, 0]}\")\n", "print(f\"The TPR and FPR for group 1 are {eo_per_group[:, 1]}\")\n", "print(f\"The absolute differences are {absolute_diff}\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:10:15.972016500Z", "start_time": "2024-07-23T15:10:14.852398300Z" } }, "id": "bd73c6dbd11d4106" },
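{ "cell_type": "markdown", "source": [ "The TPR and FPR gaps between the groups are clearly nonzero. Before retraining, we could quantify the violation that the fairret term will penalise. The cell below is a minimal, unexecuted sketch: it scores the baseline model's logits with norm_loss, mirroring exactly how it is called inside the training loop that follows." ], "metadata": { "collapsed": false }, "id": "5a6b7c8d9e0f1a2b" }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [ "# Minimal sketch (not executed here): score the baseline model's logits with the\n", "# NormLoss built on the stacked statistic, as done in the training loop below.\n", "with torch.no_grad():\n", "    baseline_violation = norm_loss(model(feat), sens, label)\n", "print(baseline_violation)" ], "metadata": { "collapsed": false }, "id": "2b1a0f9e8d7c6b5a" }, { "cell_type": "markdown", "source": [ "With fairret..."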
], "metadata": { "collapsed": false }, "id": "dea6c331b680a941" }, { "cell_type": "code", "execution_count": 4, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, loss: 0.8069422245025635\n", "Epoch: 1, loss: 0.7932361960411072\n", "Epoch: 2, loss: 0.7793688178062439\n", "Epoch: 3, loss: 0.7653393149375916\n", "Epoch: 4, loss: 0.7511466145515442\n", "Epoch: 5, loss: 0.7367900013923645\n", "Epoch: 6, loss: 0.7222684025764465\n", "Epoch: 7, loss: 0.7075802683830261\n", "Epoch: 8, loss: 0.693701446056366\n", "Epoch: 9, loss: 0.7050990462303162\n", "Epoch: 10, loss: 0.7107114195823669\n", "Epoch: 11, loss: 0.7118827104568481\n", "Epoch: 12, loss: 0.7095987200737\n", "Epoch: 13, loss: 0.704603374004364\n", "Epoch: 14, loss: 0.6974690556526184\n", "Epoch: 15, loss: 0.6966591477394104\n", "Epoch: 16, loss: 0.70110023021698\n", "Epoch: 17, loss: 0.7034574151039124\n", "Epoch: 18, loss: 0.7040092945098877\n", "Epoch: 19, loss: 0.7029833793640137\n", "Epoch: 20, loss: 0.7005670070648193\n", "Epoch: 21, loss: 0.6969154477119446\n", "Epoch: 22, loss: 0.6944364905357361\n", "Epoch: 23, loss: 0.6974213719367981\n", "Epoch: 24, loss: 0.6976830959320068\n", "Epoch: 25, loss: 0.695574164390564\n", "Epoch: 26, loss: 0.6945269703865051\n", "Epoch: 27, loss: 0.6960493922233582\n", "Epoch: 28, loss: 0.6960320472717285\n", "Epoch: 29, loss: 0.6946355700492859\n", "Epoch: 30, loss: 0.6946574449539185\n", "Epoch: 31, loss: 0.6954229474067688\n", "Epoch: 32, loss: 0.6938773989677429\n", "Epoch: 33, loss: 0.6954055428504944\n", "Epoch: 34, loss: 0.6965784430503845\n", "Epoch: 35, loss: 0.6962950229644775\n", "Epoch: 36, loss: 0.6946983933448792\n", "Epoch: 37, loss: 0.6947646141052246\n", "Epoch: 38, loss: 0.6957594752311707\n", "Epoch: 39, loss: 0.6944624781608582\n", "Epoch: 40, loss: 0.6947464942932129\n", "Epoch: 41, loss: 0.6957638263702393\n", "Epoch: 42, loss: 0.6953589916229248\n", "Epoch: 43, loss: 0.6936694383621216\n", "Epoch: 44, loss: 0.6961650252342224\n", "Epoch: 45, loss: 0.6972649693489075\n", "Epoch: 46, loss: 0.6960750222206116\n", "Epoch: 47, loss: 0.6933992505073547\n", "Epoch: 48, loss: 0.6943512558937073\n", "Epoch: 49, loss: 0.6938944458961487\n", "Epoch: 50, loss: 0.6944283843040466\n", "Epoch: 51, loss: 0.6942562460899353\n", "Epoch: 52, loss: 0.6940963864326477\n", "Epoch: 53, loss: 0.6944071054458618\n", "Epoch: 54, loss: 0.693374752998352\n", "Epoch: 55, loss: 0.6957550048828125\n", "Epoch: 56, loss: 0.6961789727210999\n", "Epoch: 57, loss: 0.6943976879119873\n", "Epoch: 58, loss: 0.6950995922088623\n", "Epoch: 59, loss: 0.6964230537414551\n", "Epoch: 60, loss: 0.6963126063346863\n", "Epoch: 61, loss: 0.6949063539505005\n", "Epoch: 62, loss: 0.6942113041877747\n", "Epoch: 63, loss: 0.6950266361236572\n", "Epoch: 64, loss: 0.6936017870903015\n", "Epoch: 65, loss: 0.6954660415649414\n", "Epoch: 66, loss: 0.6965658664703369\n", "Epoch: 67, loss: 0.696255624294281\n", "Epoch: 68, loss: 0.6946702003479004\n", "Epoch: 69, loss: 0.694717288017273\n", "Epoch: 70, loss: 0.6957194209098816\n", "Epoch: 71, loss: 0.6944576501846313\n", "Epoch: 72, loss: 0.6946854591369629\n", "Epoch: 73, loss: 0.6956840753555298\n", "Epoch: 74, loss: 0.6952812671661377\n", "Epoch: 75, loss: 0.6936118006706238\n", "Epoch: 76, loss: 0.6961743235588074\n", "Epoch: 77, loss: 0.6972628831863403\n", "Epoch: 78, loss: 0.6960713267326355\n", "Epoch: 79, loss: 0.6933831572532654\n", "Epoch: 80, loss: 0.6943389773368835\n", "Epoch: 81, loss: 0.6938958168029785\n", "Epoch: 82, loss: 
0.6943807601928711\n", "Epoch: 83, loss: 0.6941943168640137\n", "Epoch: 84, loss: 0.6941384673118591\n", "Epoch: 85, loss: 0.6944611668586731\n", "Epoch: 86, loss: 0.6934477686882019\n", "Epoch: 87, loss: 0.6956196427345276\n", "Epoch: 88, loss: 0.6960253715515137\n", "Epoch: 89, loss: 0.6942300200462341\n", "Epoch: 90, loss: 0.6952304840087891\n", "Epoch: 91, loss: 0.6965633630752563\n", "Epoch: 92, loss: 0.6964669823646545\n", "Epoch: 93, loss: 0.6950795650482178\n", "Epoch: 94, loss: 0.693950891494751\n", "Epoch: 95, loss: 0.6947492361068726\n", "Epoch: 96, loss: 0.6933133602142334\n", "Epoch: 97, loss: 0.6956915855407715\n", "Epoch: 98, loss: 0.6967988014221191\n", "Epoch: 99, loss: 0.6964995861053467\n", "The TPR and FPR for group 0 are tensor([0.5012, 0.5009], grad_fn=<SelectBackward0>)\n", "The TPR and FPR for group 1 are tensor([0.5017, 0.5014], grad_fn=<SelectBackward0>)\n", "The absolute differences are tensor([0.0005, 0.0005], grad_fn=<AbsBackward0>)\n" ] } ], "source": [ "import numpy as np\n", "\n", "nb_epochs = 100\n", "fairness_strength = 1\n", "model, optimizer = build_model()\n", "for epoch in range(nb_epochs):\n", "    losses = []\n", "    for batch_feat, batch_sens, batch_label in dataloader:\n", "        optimizer.zero_grad()\n", "\n", "        logit = model(batch_feat)\n", "        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)\n", "        loss += fairness_strength * norm_loss(logit, batch_sens, batch_label)\n", "        loss.backward()\n", "\n", "        optimizer.step()\n", "        losses.append(loss.item())\n", "    print(f\"Epoch: {epoch}, loss: {np.mean(losses)}\")\n", "\n", "# Evaluate the stacked statistic per sensitive group on the full dataset.\n", "pred = torch.sigmoid(model(feat))\n", "eo_per_group = equalised_odds_stats(pred, sens, label)\n", "absolute_diff = torch.abs(eo_per_group[:, 0] - eo_per_group[:, 1])\n", "\n", "print(f\"The TPR and FPR for group 0 are {eo_per_group[:, 0]}\")\n", "print(f\"The TPR and FPR for group 1 are {eo_per_group[:, 1]}\")\n", "print(f\"The absolute differences are {absolute_diff}\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:10:16.527198300Z", "start_time": "2024-07-23T15:10:15.955048500Z" } }, "id": "4bc786354387b2af" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3" } }, "nbformat": 4, "nbformat_minor": 5 }